package cn.edu.dgut.css.sai.security.oauth2.filter;

import org.apache.catalina.util.ParameterMap;
import org.springframework.security.oauth2.client.web.AuthorizationRequestRepository;
import org.springframework.security.oauth2.client.web.HttpSessionOAuth2AuthorizationRequestRepository;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames;
import org.springframework.security.web.util.UrlUtils;
import org.springframework.util.LinkedMultiValueMap;
import org.springframework.util.MultiValueMap;
import org.springframework.util.StringUtils;
import org.springframework.web.filter.OncePerRequestFilter;

import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.Collections;
import java.util.Enumeration;
import java.util.LinkedHashMap;
import java.util.Map;

/**
 * Servlet filter that adapts the authorization response of the school's central login
 * (the {@code "dgut"} client registration) to the standard OAuth 2.0 shape.
 *
 * <p>The central login returns the authorization code in a request parameter named
 * {@code token} instead of the standard {@code code}. When the current request is the redirect
 * back for a pending {@code "dgut"} authorization request, this filter hands a request wrapper
 * down the chain that exposes the {@code token} value under {@code code}. All other requests are
 * passed through unchanged.
 *
 * <p>The rename is done with an {@link HttpServletRequestWrapper}. The container's own parameter
 * map is never unlocked or mutated, so the filter does not depend on a particular servlet
 * container's internal classes.
 *
 * @author Sai
 * @see org.springframework.security.oauth2.client.web.OAuth2LoginAuthenticationFilter
 * @see org.springframework.security.oauth2.client.web.OAuth2AuthorizationCodeGrantFilter
 * @since 1.0
 */
public final class DgutAuthorizationResponseFilter extends OncePerRequestFilter {

    /** Client registration id of the school's central login. */
    private static final String DGUT_REGISTRATION_ID = "dgut";

    /** Name of the parameter in which the central login returns the authorization code. */
    private static final String DGUT_CODE_PARAMETER = "token";

    private final AuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository =
            new HttpSessionOAuth2AuthorizationRequestRepository();

    /**
     * Continues the filter chain.
     *
     * <p>If the request is the central login's authorization response, the chain receives a
     * wrapped request in which {@code token} has been renamed to {@code code}. Otherwise the chain
     * receives the original request.
     */
    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws IOException, ServletException {
        HttpServletRequest requestToUse = request;
        // Load the pending authorization request once and reuse it for every check below.
        OAuth2AuthorizationRequest authorizationRequest =
                this.authorizationRequestRepository.loadAuthorizationRequest(request);
        if (authorizationRequest != null
                && shouldProcessAuthorizationResponse(request, authorizationRequest)
                && isDgutAuthorizationRequest(authorizationRequest)
                && request.getParameterMap().containsKey(DGUT_CODE_PARAMETER)) {
            requestToUse = new CodeParameterRequestWrapper(request);
        }
        filterChain.doFilter(requestToUse, response);
    }

    /**
     * Returns whether {@code request} is an authorization response for {@code authorizationRequest}.
     *
     * <p>This is true when the request URL (without query string) equals the registered redirect
     * URI and the parameters look like a success ({@code token} + {@code state}) or an error
     * ({@code error} + {@code state}) response.
     */
    private static boolean shouldProcessAuthorizationResponse(HttpServletRequest request,
                                                              OAuth2AuthorizationRequest authorizationRequest) {
        String requestUrl = UrlUtils.buildFullRequestUrl(request.getScheme(), request.getServerName(),
                request.getServerPort(), request.getRequestURI(), null);
        MultiValueMap<String, String> params = OAuth2AuthorizationResponseUtils.toMultiMap(request.getParameterMap());
        return requestUrl.equals(authorizationRequest.getRedirectUri())
                && OAuth2AuthorizationResponseUtils.isAuthorizationResponse(params);
    }

    /**
     * Returns whether {@code authorizationRequest} was issued for the {@code "dgut"} client
     * registration.
     */
    private static boolean isDgutAuthorizationRequest(OAuth2AuthorizationRequest authorizationRequest) {
        // Constant-first equals: the attribute may be absent (or not a String) for requests issued
        // for other registrations, which must yield false rather than an exception.
        Object registrationId = authorizationRequest.getAttributes().get(OAuth2ParameterNames.REGISTRATION_ID);
        return DGUT_REGISTRATION_ID.equals(registrationId);
    }

    /**
     * Request wrapper that exposes the central login's {@code token} parameter as the standard
     * {@code code} parameter. All other parameters are passed through unchanged.
     */
    private static final class CodeParameterRequestWrapper extends HttpServletRequestWrapper {

        /** Snapshot of the wrapped request's parameters with {@code token} renamed to {@code code}. */
        private final Map<String, String[]> parameters;

        CodeParameterRequestWrapper(HttpServletRequest request) {
            super(request);
            Map<String, String[]> renamed = new LinkedHashMap<>(request.getParameterMap());
            // Only the first value is carried over, as HttpServletRequest#getParameter would return it.
            String token = request.getParameter(DGUT_CODE_PARAMETER);
            renamed.remove(DGUT_CODE_PARAMETER);
            renamed.put(OAuth2ParameterNames.CODE, new String[]{token});
            this.parameters = Collections.unmodifiableMap(renamed);
        }

        @Override
        public String getParameter(String name) {
            String[] values = this.parameters.get(name);
            return (values == null || values.length == 0) ? null : values[0];
        }

        @Override
        public Map<String, String[]> getParameterMap() {
            return this.parameters;
        }

        @Override
        public Enumeration<String> getParameterNames() {
            return Collections.enumeration(this.parameters.keySet());
        }

        @Override
        public String[] getParameterValues(String name) {
            String[] values = this.parameters.get(name);
            return (values == null) ? null : values.clone();
        }
    }

    /**
     * Helpers for recognising authorization responses.
     *
     * <p>Adapted from {@code org.springframework.security.oauth2.client.web.OAuth2AuthorizationResponseUtils}.
     * The success check looks for the central login's {@code token} parameter instead of {@code code}.
     */
    static final class OAuth2AuthorizationResponseUtils {

        private OAuth2AuthorizationResponseUtils() {
        }

        /** Copies a servlet parameter map into a {@link MultiValueMap}, preserving every value. */
        static MultiValueMap<String, String> toMultiMap(Map<String, String[]> map) {
            MultiValueMap<String, String> params = new LinkedMultiValueMap<>(map.size());
            map.forEach((key, values) -> {
                for (String value : values) {
                    params.add(key, value);
                }
            });
            return params;
        }

        /** Returns whether the parameters form either a success or an error authorization response. */
        static boolean isAuthorizationResponse(MultiValueMap<String, String> request) {
            return isAuthorizationResponseSuccess(request) || isAuthorizationResponseError(request);
        }

        /**
         * Returns whether the parameters form a success response, meaning both {@code token} and
         * {@code state} have text. Upstream Spring Security checks {@code code} here; the central
         * login sends {@code token} instead.
         */
        static boolean isAuthorizationResponseSuccess(MultiValueMap<String, String> request) {
            return StringUtils.hasText(request.getFirst(DGUT_CODE_PARAMETER))
                    && StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE));
        }

        /** Returns whether the parameters form an error response, meaning both {@code error} and {@code state} have text. */
        static boolean isAuthorizationResponseError(MultiValueMap<String, String> request) {
            return StringUtils.hasText(request.getFirst(OAuth2ParameterNames.ERROR))
                    && StringUtils.hasText(request.getFirst(OAuth2ParameterNames.STATE));
        }

    }

}
